
import os
import numpy as np
from src.printing.pixels import print_matrix_pixels
from src.clustering.utils import get_average_support


class Summary:
    """Cluster-level summary of a voter x candidate profile.

    The summary is a *meta profile* (rows = voter clusters, columns =
    candidate clusters) together with the assignment of every voter and every
    candidate to its cluster.

    Attributes:
        meta_profile: 2-D array of shape (num_v_clusters, num_c_clusters).
        voters_mapping: ``voters_mapping[i]`` is the cluster index of voter ``i``.
        candidates_mapping: ``candidates_mapping[j]`` is the cluster index of
            candidate ``j``.
        v_clusters_weights, c_clusters_weights: float arrays holding the number
            of members of each voter / candidate cluster.
        v_sets, c_sets: lists of sets holding the member indices of each
            voter / candidate cluster.
        expanded_summary: cached expanded matrix, ``None`` until computed.
    """

    def __init__(self, meta_profile, voters_mapping, candidates_mapping):
        """Store the summary and derive cluster sizes and member sets.

        Args:
            meta_profile: 2-D array; axis 0 indexes voter clusters and axis 1
                indexes candidate clusters.
            voters_mapping: sequence of voter-cluster indices, one per voter.
            candidates_mapping: sequence of candidate-cluster indices, one per
                candidate.
        """
        self.meta_profile = meta_profile
        self.voters_mapping = voters_mapping
        self.candidates_mapping = candidates_mapping

        self.num_v_clusters = meta_profile.shape[0]
        self.num_c_clusters = meta_profile.shape[1]
        self.num_voters = len(voters_mapping)
        self.num_candidates = len(candidates_mapping)

        self.v_clusters_weights = None
        self.c_clusters_weights = None
        self.compute_weights()

        self.v_sets = None
        self.c_sets = None
        self.compute_sets()

        # Filled lazily by expand_visually() / print_expanded_summary().
        self.expanded_summary = None

    def compute_weights(self):
        """Count the members of every voter and candidate cluster.

        Stores float arrays of length ``num_v_clusters`` / ``num_c_clusters``
        in ``v_clusters_weights`` / ``c_clusters_weights``.
        """
        self.v_clusters_weights = np.zeros(self.num_v_clusters)
        self.c_clusters_weights = np.zeros(self.num_c_clusters)

        for cluster in self.voters_mapping:
            self.v_clusters_weights[cluster] += 1
        for cluster in self.candidates_mapping:
            self.c_clusters_weights[cluster] += 1

    def compute_sets(self):
        """Collect the member indices of every voter and candidate cluster.

        Stores lists of sets in ``v_sets`` / ``c_sets``; entry ``k`` holds the
        indices of the voters / candidates assigned to cluster ``k``.
        """
        self.v_sets = [set() for _ in range(self.num_v_clusters)]
        self.c_sets = [set() for _ in range(self.num_c_clusters)]

        for voter, cluster in enumerate(self.voters_mapping):
            self.v_sets[cluster].add(voter)
        for candidate, cluster in enumerate(self.candidates_mapping):
            self.c_sets[cluster].add(candidate)

    def print_meta_profile(self, **kwargs):
        """Render ``meta_profile`` with ``print_matrix_pixels``; kwargs are forwarded."""
        print_matrix_pixels(self.meta_profile, **kwargs)

    def print_expanded_summary(self, **kwargs):
        """Render the expanded summary, computing it with ``expand()`` if not cached.

        Keyword arguments are forwarded to ``print_matrix_pixels``.
        """
        if self.expanded_summary is None:
            # expand() only returns its result, so it must be stored here;
            # otherwise ``None`` would be passed to the printer.
            self.expanded_summary = self.expand()
        print_matrix_pixels(self.expanded_summary, **kwargs)

    def print_meta_profile_with_cluster_sizes(self, **kwargs):
        """Render ``meta_profile`` annotated with the voter and candidate cluster sizes."""
        self.print_meta_profile(
            cluster_sizes=[self.v_clusters_weights, self.c_clusters_weights],
            **kwargs
        )

    def expand(self):
        """Return the voter x candidate matrix implied by the summary.

        Entry ``(i, j)`` is
        ``meta_profile[voters_mapping[i], candidates_mapping[j]]``, so rows
        and columns follow the original voter / candidate order.

        Returns:
            float array of shape (num_voters, num_candidates).
        """
        voter_clusters = np.asarray(self.voters_mapping, dtype=int)
        candidate_clusters = np.asarray(self.candidates_mapping, dtype=int)
        # Direct fancy indexing replaces the former one-hot matmuls. It is
        # cheaper, and a NaN in one meta-profile cell can no longer leak into
        # unrelated entries through 0 * NaN == NaN.
        expanded = np.asarray(self.meta_profile)[np.ix_(voter_clusters, candidate_clusters)]
        return expanded.astype(float)

    def expand_visually(self):
        """Return (and cache) the meta profile blown up by cluster sizes.

        Row ``k`` of ``meta_profile`` is repeated ``v_clusters_weights[k]``
        times and column ``l`` ``c_clusters_weights[l]`` times. Rows and
        columns are therefore grouped by cluster index, not in the original
        voter / candidate order. The result is also stored in
        ``expanded_summary``.
        """
        expanded_summary = np.repeat(self.meta_profile, np.array(self.v_clusters_weights, dtype=int), axis=0)
        expanded_summary = np.repeat(expanded_summary, np.array(self.c_clusters_weights, dtype=int), axis=1)
        self.expanded_summary = expanded_summary

        return expanded_summary

    def export(self, filename):
        """Write ``meta_profile`` to ``results/<filename>.csv`` (comma separated, 5 decimals).

        The ``results`` directory is created relative to the current working
        directory if it does not exist.
        """
        # exist_ok avoids the check-then-create race of exists()/makedirs().
        os.makedirs("results", exist_ok=True)

        np.savetxt(f"results/{filename}.csv", self.meta_profile, fmt='%.5f', delimiter=",")

    def _v_cluster_rename(self):
        """Relabel voter clusters by the average index of their members.

        After the call, label 0 belongs to the cluster whose voters have the
        smallest mean index. ``voters_mapping`` and the rows of
        ``meta_profile`` are permuted consistently.
        """
        average_positions = calculate_average_positions(self.voters_mapping, self.num_v_clusters)

        # order[new_label] == old_label
        order = sorted(average_positions, key=average_positions.get)
        new_label = {old: new for new, old in enumerate(order)}

        self.voters_mapping = np.array([new_label[label] for label in self.voters_mapping])

        # New row ``new`` must be old row ``order[new]``. Indexing with
        # ``new_label`` would apply the inverse permutation instead, which is
        # only correct when the relabeling consists of swaps.
        self.meta_profile = self.meta_profile[order, :]

    def _c_cluster_rename(self):
        """Relabel candidate clusters by the average index of their members.

        After the call, label 0 belongs to the cluster whose candidates have
        the smallest mean index. ``candidates_mapping`` and the columns of
        ``meta_profile`` are permuted consistently.
        """
        average_positions = calculate_average_positions(self.candidates_mapping, self.num_c_clusters)

        # order[new_label] == old_label
        order = sorted(average_positions, key=average_positions.get)
        new_label = {old: new for new, old in enumerate(order)}

        self.candidates_mapping = np.array([new_label[label] for label in self.candidates_mapping])

        # New column ``new`` must be old column ``order[new]`` (see _v_cluster_rename).
        self.meta_profile = self.meta_profile[:, order]

    def cluster_rename(self):
        """Relabel voter and candidate clusters into positional order and refresh derived data."""
        self._v_cluster_rename()
        self._c_cluster_rename()
        self.compute_weights()
        self.compute_sets()
        # Any cached expansion was built from the old labelling.
        self.expanded_summary = None


def calculate_average_positions(lst, num_clusters):
    """Return the mean index at which each cluster label occurs in *lst*.

    Args:
        lst: sequence of cluster labels; ``lst[i]`` is the label of item ``i``.
        num_clusters: labels are expected to lie in ``range(num_clusters)``.

    Returns:
        dict mapping every label in ``range(num_clusters)`` (in increasing
        order) to the average index of its occurrences, or ``float('inf')``
        when the label never occurs.

    Raises:
        KeyError: if *lst* contains a label outside ``range(num_clusters)``.
    """
    index_sum = dict.fromkeys(range(num_clusters), 0)
    occurrences = dict.fromkeys(range(num_clusters), 0)

    for index, label in enumerate(lst):
        # Subscripting first makes an unknown label raise KeyError.
        index_sum[label] += index
        occurrences[label] += 1

    return {
        label: (index_sum[label] / occurrences[label]) if occurrences[label] else float('inf')
        for label in index_sum
    }


def _get_total_length(dict_of_sets):
    return len(set().union(*dict_of_sets.values()))


def _convert_sets_to_mapping(sets):
    # convert a dictionary of sets to a mapping
    length = _get_total_length(sets)
    mapping = [None for _ in range(length)]
    for i, s in enumerate(sets):
        for j in sets[s]:
            mapping[int(j)] = i
    return mapping


def get_summary_from_sets(profile, v_sets, c_sets):
    """Build a ``Summary`` from explicitly given voter and candidate clusters.

    Args:
        profile: forwarded unchanged to ``get_average_support``.
        v_sets: dict-like collection of voter clusters (cluster -> set of
            voter indices).
        c_sets: dict-like collection of candidate clusters (cluster -> set of
            candidate indices).

    Returns:
        Summary whose meta profile is the result of ``get_average_support``
        for these clusters, and whose mappings are derived from the sets.
    """
    n_voter_clusters, n_candidate_clusters = len(v_sets), len(c_sets)

    voter_labels = _convert_sets_to_mapping(v_sets)
    candidate_labels = _convert_sets_to_mapping(c_sets)

    meta_profile = get_average_support(
        profile,
        voter_labels,
        candidate_labels,
        n_voter_clusters,
        n_candidate_clusters,
    )
    return Summary(meta_profile, voter_labels, candidate_labels)